{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Copyright 2023 Google LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d891d022-8979-40d4-848f-ecb84c17f12c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Copyright 2023 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "540d8642-c203-471c-a66d-0d43aabb0706",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# StyleAligned over SD1.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from diffusers import DDIMScheduler,StableDiffusionPipeline\n",
    "import torch\n",
    "import mediapy\n",
    "import sa_handler\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "522b14e7-9768-4eaa-8433-bf88acb244c4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# DDIM scheduler handed to the pipeline below.\n",
    "scheduler = DDIMScheduler(\n",
    "    beta_start=0.00085,\n",
    "    beta_end=0.012,\n",
    "    beta_schedule=\"scaled_linear\",\n",
    "    clip_sample=False,\n",
    "    set_alpha_to_one=False,\n",
    ")\n",
    "\n",
    "# The \"fp16\" revision holds half-precision weights. Without an explicit\n",
    "# `torch_dtype`, diffusers loads them in its default float32, so request\n",
    "# float16 to keep the model in half precision on the GPU.\n",
    "pipeline = StableDiffusionPipeline.from_pretrained(\n",
    "    \"CompVis/stable-diffusion-v1-4\",\n",
    "    revision=\"fp16\",\n",
    "    torch_dtype=torch.float16,\n",
    "    scheduler=scheduler,\n",
    ")\n",
    "pipeline = pipeline.to(\"cuda\")\n",
    "\n",
    "# Build the StyleAligned handler around the pipeline and register the\n",
    "# sharing / AdaIN options selected below.\n",
    "handler = sa_handler.Handler(pipeline)\n",
    "sa_args = sa_handler.StyleAlignedArgs(\n",
    "    share_group_norm=True,\n",
    "    share_layer_norm=True,\n",
    "    share_attention=True,\n",
    "    adain_queries=True,\n",
    "    adain_keys=True,\n",
    "    adain_values=False,\n",
    ")\n",
    "\n",
    "handler.register(sa_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5db98c81-8b72-4fc7-8cd0-65eda17198e3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# run StyleAligned\n",
    "\n",
    "# Fixed seed so that Restart & Run All reproduces the same images;\n",
    "# change it to sample a different set.\n",
    "SEED = 0\n",
    "\n",
    "# One prompt per generated image; all prompts are passed to the pipeline in a single call.\n",
    "sets_of_prompts = [\n",
    "  \"a toy train. macro photo. 3d game asset\",\n",
    "  \"a toy airplane. macro photo. 3d game asset\",\n",
    "  \"a toy bicycle. macro photo. 3d game asset\",\n",
    "  \"a toy car. macro photo. 3d game asset\",\n",
    "  \"a toy boat. macro photo. 3d game asset\",\n",
    "]\n",
    "\n",
    "# Seed the sampling noise explicitly: an unseeded call yields different\n",
    "# images on every run. A CPU generator is used so the result does not\n",
    "# depend on the GPU's random number generator.\n",
    "generator = torch.Generator(device=\"cpu\").manual_seed(SEED)\n",
    "images = pipeline(sets_of_prompts, generator=generator).images\n",
    "mediapy.show_images(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afbe3876-22d9-4735-89b9-d5b5c46aea5c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}